//
// Copyright Coinbase, Inc. All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

package elgamal

import (
	"crypto/aes"
	"crypto/cipher"
	"fmt"
	"git.sr.ht/~sircmpwn/go-bare"
	"github.com/coinbase/kryptology/internal"
	"github.com/coinbase/kryptology/pkg/core"
	"github.com/coinbase/kryptology/pkg/core/curves"
	"math/big"
)

// decryptionKeyMarshal is the BARE wire representation of a DecryptionKey,
// produced by MarshalBinary and consumed by UnmarshalBinary.
// NOTE(review): BARE encodes struct fields positionally, so reordering or
// inserting fields changes the serialized format — confirm before editing.
type decryptionKeyMarshal struct {
	// X is the secret scalar as returned by curves.Scalar.Bytes.
	X     []byte `bare:"x"`
	// Curve is the curve name, used with curves.GetCurveByName on decode.
	Curve string `bare:"curve"`
}

// DecryptionKey decrypts verifiable ciphertext and verifies proofs.
// It holds the secret scalar x; the matching public key is x·G, returned by
// EncryptionKey. The zero value has no scalar and is not usable.
type DecryptionKey struct {
	// x is the secret scalar; its curve determines the curve of every
	// point this key operates on.
	x curves.Scalar
}

// EncryptionKey derives the public encryption key for this decryption key,
// i.e. the point x·G where G is the generator of the scalar's curve.
func (dk DecryptionKey) EncryptionKey() *EncryptionKey {
	gen := dk.x.Point().Generator()
	pub := gen.Mul(dk.x)
	return &EncryptionKey{value: pub}
}

// MarshalBinary serializes the key with BARE, encoding the secret scalar's
// bytes together with its curve name so UnmarshalBinary can restore it on
// the same curve.
//
// A zero-value DecryptionKey has no scalar; instead of panicking,
// MarshalBinary returns internal.ErrNilArguments for it.
func (dk DecryptionKey) MarshalBinary() ([]byte, error) {
	if dk.x == nil {
		return nil, internal.ErrNilArguments
	}
	tv := &decryptionKeyMarshal{
		X:     dk.x.Bytes(),
		Curve: dk.x.Point().CurveName(),
	}
	return bare.Marshal(tv)
}

// UnmarshalBinary deserializes a key produced by MarshalBinary.
//
// It decodes the BARE payload, looks up the curve by its serialized name,
// and decodes the scalar on that curve. Errors from the underlying decoders
// are wrapped with %w, so errors.Is / errors.As still work. The receiver is
// modified only on success.
func (dk *DecryptionKey) UnmarshalBinary(data []byte) error {
	tv := new(decryptionKeyMarshal)
	if err := bare.Unmarshal(data, tv); err != nil {
		return fmt.Errorf("decoding decryption key: %w", err)
	}
	curve := curves.GetCurveByName(tv.Curve)
	if curve == nil {
		// Report the offending name so malformed or foreign keys are diagnosable.
		return fmt.Errorf("unknown curve %q", tv.Curve)
	}
	x, err := curve.Scalar.SetBytes(tv.X)
	if err != nil {
		return fmt.Errorf("decoding decryption key scalar: %w", err)
	}
	dk.x = x
	return nil
}

// Decrypt returns the resulting point from El-Gamal decryption:
// M = H * m = C2 - C1 * x.
//
// No checking is performed whether the ciphertext has been modified; a
// tampered ciphertext simply decrypts to a different point.
//
// It returns nil if the key is uninitialized (nil receiver or nil scalar)
// or if the ciphertext or either of its components is nil.
func (dk *DecryptionKey) Decrypt(cipherText *HomomorphicCipherText) curves.Point {
	if dk == nil || dk.x == nil {
		return nil
	}
	if cipherText == nil || cipherText.c1 == nil || cipherText.c2 == nil {
		return nil
	}
	return cipherText.c2.Sub(cipherText.c1.Mul(dk.x))
}

// VerifiableDecrypt decrypts the ciphertext and checks the result against the
// El-Gamal C2 value: the recovered message scalar m must satisfy
// G·m == C2 - x·C1, where G is the curve generator.
// If the plaintext does not match, an error is returned.
// The ciphertext must have been generated by Encrypt.
//
// It returns the plaintext bytes and the message scalar. It fails if the
// ciphertext is malformed, if the AEAD payload does not authenticate, or if
// the El-Gamal check fails.
func (dk DecryptionKey) VerifiableDecrypt(cipherText *CipherText) ([]byte, curves.Scalar, error) {
	msgBytes, msgScalar, rhs, err := dk.decryptData(cipherText)
	if err != nil {
		return nil, nil, err
	}
	// The encryption key's Generator() is just the curve generator; take it
	// from the scalar's curve directly rather than first computing x·G.
	h := dk.x.Point().Generator()
	if !h.Mul(msgScalar).Equal(rhs) {
		return nil, nil, fmt.Errorf("ciphertext mismatch")
	}
	return msgBytes, msgScalar, nil
}

// VerifiableDecryptWithDomain decrypts the ciphertext and checks the result
// against the El-Gamal C2 value, using a domain-separated generator
// H = Hash(domain || Q || nonce), where Q is this key's encryption key in
// uncompressed affine form. If the plaintext does not match, an error is
// returned.
//
// The domain component is meant for scenarios where `msg` is used in more
// than one setting and should be contextualized. The ciphertext must have
// been generated by EncryptWithDomain with the same domain.
//
// The caller's domain slice is never modified.
func (dk DecryptionKey) VerifiableDecryptWithDomain(domain []byte, cipherText *CipherText) ([]byte, curves.Scalar, error) {
	msgBytes, msgScalar, rhs, err := dk.decryptData(cipherText)
	if err != nil {
		return nil, nil, err
	}
	ek := dk.EncryptionKey()
	pub := ek.value.ToAffineUncompressed()

	// Build the hash input in a fresh buffer: append(domain, ...) would write
	// into the caller's backing array whenever domain has spare capacity.
	genBytes := make([]byte, 0, len(domain)+len(pub)+len(cipherText.nonce))
	genBytes = append(genBytes, domain...)
	genBytes = append(genBytes, pub...)
	genBytes = append(genBytes, cipherText.nonce...)

	h := ek.value.Hash(genBytes)
	if !h.Mul(msgScalar).Equal(rhs) {
		return nil, nil, fmt.Errorf("ciphertext mismatch")
	}
	return msgBytes, msgScalar, nil
}

// decryptData does the work shared by VerifiableDecrypt and
// VerifiableDecryptWithDomain:
//   - computes t = x·C1 and derives the AES-GCM key from it via core.FiatShamir,
//   - opens the AEAD payload with AAD = C1 || C2 (uncompressed affine encodings),
//   - converts the plaintext to a scalar: decoded with SetBytes when
//     cipherText.msgIsHashed is set, otherwise hashed with Scalar.Hash,
//   - returns rhs = C2 - t, which callers compare against generator·msg.
//
// It returns internal.ErrNilArguments if the key or any ciphertext component
// is nil, and internal.ErrZeroValue if the nonce or AEAD payload is too short.
// It returns an error if the nonce is not exactly the GCM nonce size, and it
// propagates the AEAD error if authentication fails.
func (dk DecryptionKey) decryptData(cipherText *CipherText) ([]byte, curves.Scalar, curves.Point, error) {
	if dk.x == nil || cipherText == nil {
		return nil, nil, nil, internal.ErrNilArguments
	}
	if cipherText.c1 == nil || cipherText.c2 == nil || cipherText.nonce == nil || cipherText.aead == nil {
		return nil, nil, nil, internal.ErrNilArguments
	}
	// Reject obviously short inputs up front (12-byte GCM nonce, 16-byte tag).
	if len(cipherText.nonce) < 12 || len(cipherText.aead) < 16 {
		return nil, nil, nil, internal.ErrZeroValue
	}
	// t = x·C1 (r * Q for an honestly generated ciphertext).
	t := cipherText.c1.Mul(dk.x)

	aeadKey, err := core.FiatShamir(new(big.Int).SetBytes(t.ToAffineCompressed()))
	if err != nil {
		return nil, nil, nil, err
	}
	block, err := aes.NewCipher(aeadKey)
	if err != nil {
		return nil, nil, nil, err
	}
	aesGcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, nil, nil, err
	}
	// cipher.AEAD.Open panics on any nonce whose length differs from
	// NonceSize, so an over-long nonce must be rejected as well.
	if len(cipherText.nonce) != aesGcm.NonceSize() {
		return nil, nil, nil, fmt.Errorf("invalid nonce length %d, want %d", len(cipherText.nonce), aesGcm.NonceSize())
	}

	// AAD = C1 || C2 binds the AEAD payload to the El-Gamal components.
	aad := cipherText.c1.ToAffineUncompressed()
	aad = append(aad, cipherText.c2.ToAffineUncompressed()...)
	msgBytes, err := aesGcm.Open(nil, cipherText.nonce, cipherText.aead, aad)
	if err != nil {
		return nil, nil, nil, err
	}

	var msg curves.Scalar
	if cipherText.msgIsHashed {
		msg, err = dk.x.New(0).SetBytes(msgBytes)
		if err != nil {
			return nil, nil, nil, err
		}
	} else {
		msg = dk.x.New(0).Hash(msgBytes)
	}

	rhs := cipherText.c2.Sub(t)
	return msgBytes, msg, rhs, nil
}
